#include "./Kalman_Filter.hpp"

/* 根据t-1时刻的状态量和t时刻控制量预测t时刻状态预测量 */
void Kalman_Filter_Function::Forecast_State_Variable(const arm_matrix_instance_f32 &A,           /* 状态转换矩阵A */
                                                     const arm_matrix_instance_f32 &sys_state,   /* t-1时刻的状态量 */
                                                     const arm_matrix_instance_f32 &B,           /* 控制转换矩阵B */
                                                     const arm_matrix_instance_f32 &input,       /* 控制量 */
                                                     arm_matrix_instance_f32 &sys_state_forecast /* t时刻的状态预测量 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 计算A*sys_state矩阵的大小 */
    row = A.numRows;
    column = sys_state.numCols;

    /* 分配并计算A*sys_statet矩阵 */
    float A_mul_sys_state_data[row * column];
    arm_matrix_instance_f32 A_mul_sys_state_matrix;
    arm_mat_init_f32(&A_mul_sys_state_matrix, row, column, A_mul_sys_state_data);
    arm_mat_mult_f32(&A, &sys_state, &A_mul_sys_state_matrix);

    Kalman_Filter_Function::Forecast_State_Variable(A_mul_sys_state_matrix, B, input, sys_state_forecast); /* 计算状态预测量 */
}

/* 更新预测协方差矩阵 */
void Kalman_Filter_Function::Forecast_Covariance(const arm_matrix_instance_f32 A,             /* 状态转换矩阵A */
                                                 const arm_matrix_instance_f32 &covariance,   /* t-1时刻的协方差矩阵 */
                                                 const arm_matrix_instance_f32 &R,            /* 控制噪声矩阵R */
                                                 arm_matrix_instance_f32 &covariance_forecast /* t时刻的协方差预测矩阵 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算A*Sigma*A' */
    row = A.numRows;
    column = A.numRows;
    float A_mul_Sigma_mul_A_Trans_data[row * column];
    arm_matrix_instance_f32 A_mul_Sigma_mul_A_Trans_matrix;
    arm_mat_init_f32(&A_mul_Sigma_mul_A_Trans_matrix, row, column, A_mul_Sigma_mul_A_Trans_data);
    Kalman_Filter_Function::Cal_A_mul_B_A_Trans(A, covariance, A_mul_Sigma_mul_A_Trans_matrix);

    Kalman_Filter_Function::Forecast_Covariance(A_mul_Sigma_mul_A_Trans_matrix, R, covariance_forecast);
}

/* 计算卡尔曼增益矩阵 */
void Kalman_Filter_Function::Cal_Kalman_Gain(const arm_matrix_instance_f32 &covariance_forecast, /* t时刻的协方差预测矩阵 */
                                             const arm_matrix_instance_f32 &C,                   /* 测量转换矩阵C */
                                             const arm_matrix_instance_f32 &Q,                   /* 测量噪声矩阵Q */
                                             arm_matrix_instance_f32 &K                          /* 卡尔曼增益矩阵 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算C*Sigma*C' */
    row = C.numRows;
    column = C.numRows;
    float C_mul_Sigma_mul_C_Trans_data[row * column];
    arm_matrix_instance_f32 C_mul_Sigma_mul_C_Trans_matrix;
    arm_mat_init_f32(&C_mul_Sigma_mul_C_Trans_matrix, row, column, C_mul_Sigma_mul_C_Trans_data);
    Kalman_Filter_Function::Cal_A_mul_B_A_Trans(C, covariance_forecast, C_mul_Sigma_mul_C_Trans_matrix);

    /* 分配并计C' */
    row = C.numCols;
    column = C.numRows;
    float C_Trans_data[row * column];
    arm_matrix_instance_f32 C_Trans_matrix;
    arm_mat_init_f32(&C_Trans_matrix, row, column, C_Trans_data);
    arm_mat_trans_f32(&C, &C_Trans_matrix);

    /* 分配并计Sigma*C' */
    row = covariance_forecast.numRows;
    column = C_Trans_matrix.numCols;
    float Sigma_mul_C_Trans_data[row * column];
    arm_matrix_instance_f32 Sigma_mul_C_Trans_matrix;
    arm_mat_init_f32(&Sigma_mul_C_Trans_matrix, row, column, Sigma_mul_C_Trans_data);
    arm_mat_mult_f32(&covariance_forecast, &C_Trans_matrix, &Sigma_mul_C_Trans_matrix);

    /* 分配并计算C*Sigma*C'+Q */
    float C_mul_Sigma_mul_C_Trans_add_Q_data[row * column];
    arm_matrix_instance_f32 C_mul_Sigma_mul_C_Trans_add_Q_matrix;
    arm_mat_init_f32(&C_mul_Sigma_mul_C_Trans_add_Q_matrix, row, column, C_mul_Sigma_mul_C_Trans_add_Q_data);
    arm_mat_add_f32(&C_mul_Sigma_mul_C_Trans_matrix, &Q, &C_mul_Sigma_mul_C_Trans_add_Q_matrix);

    /* 分配并计算C_mul_Sigma_mul_C_Trans_add_Q_matrix^-1 */
    row = C_mul_Sigma_mul_C_Trans_add_Q_matrix.numCols;
    column = C_mul_Sigma_mul_C_Trans_add_Q_matrix.numRows;
    float C_mul_Sigma_mul_C_Trans_add_Q_inv_data[row * column];
    arm_matrix_instance_f32 C_mul_Sigma_mul_C_Trans_add_Q_inv_matrix;
    arm_mat_init_f32(&C_mul_Sigma_mul_C_Trans_add_Q_inv_matrix, row, column, C_mul_Sigma_mul_C_Trans_add_Q_inv_data);

    arm_status mat_result = arm_mat_inverse_f32(&C_mul_Sigma_mul_C_Trans_add_Q_matrix, &C_mul_Sigma_mul_C_Trans_add_Q_inv_matrix);

    if (ARM_MATH_SINGULAR != mat_result)
    {
        /* 计算卡尔曼增益矩阵 */
        arm_mat_mult_f32(&Sigma_mul_C_Trans_matrix, &C_mul_Sigma_mul_C_Trans_add_Q_inv_matrix, &K);
    }
}

/* 使用测量值更新状态量 */
void Kalman_Filter_Function::Update_State_Variable(const arm_matrix_instance_f32 &sys_state_forecast, /* t时刻的状态预测量 */
                                                   const arm_matrix_instance_f32 &K,                  /* 卡尔曼增益矩阵 */
                                                   const arm_matrix_instance_f32 &measurement,        /* t时刻测量量 */
                                                   const arm_matrix_instance_f32 &C,                  /* 测量转换矩阵C */
                                                   arm_matrix_instance_f32 &sys_state                 /* t时刻的状态量 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算C*sys_state_forecast矩阵 */
    row = C.numRows;
    column = sys_state_forecast.numCols;
    float C_mul_sys_state_forecast_data[row * column];
    arm_matrix_instance_f32 C_mul_sys_state_forecast_matrix;
    arm_mat_init_f32(&C_mul_sys_state_forecast_matrix, row, column, C_mul_sys_state_forecast_data);
    arm_mat_mult_f32(&C, &sys_state_forecast, &C_mul_sys_state_forecast_matrix);

    Kalman_Filter_Function::Update_State_Variable(C_mul_sys_state_forecast_matrix, K, measurement, sys_state);
}

/* 更新协方差 */
void Kalman_Filter_Function::Update_Covariance(const arm_matrix_instance_f32 &covariance_forecast, /* t时刻的协方差预测矩阵 */
                                               const arm_matrix_instance_f32 &I,                   /* 单位矩阵 */
                                               const arm_matrix_instance_f32 &K,                   /* 卡尔曼增益矩阵 */
                                               const arm_matrix_instance_f32 &C,                   /* 测量转换矩阵C */
                                               arm_matrix_instance_f32 &covariance                 /* t时刻的协方差矩阵 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算K*C矩阵 */
    row = K.numRows;
    column = C.numCols;
    float K_mul_C_data[row * column];
    arm_matrix_instance_f32 K_mul_C_matrix;
    arm_mat_init_f32(&K_mul_C_matrix, row, column, K_mul_C_data);
    arm_mat_mult_f32(&K, &C, &K_mul_C_matrix);

    Kalman_Filter_Function::Update_Covariance(covariance_forecast, I, K_mul_C_matrix, covariance);
}

/**
 * @description: 计算正定矩阵 output=A*B*A'
 * @param {type} 
 * @return: 
 */
void Kalman_Filter_Function::Cal_A_mul_B_A_Trans(const arm_matrix_instance_f32 &A, const arm_matrix_instance_f32 &B, arm_matrix_instance_f32 &output)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算A*B矩阵 */
    row = A.numRows;
    column = B.numCols;
    float A_mul_B_data[row * column];
    arm_matrix_instance_f32 A_mul_B_matrix;
    arm_mat_init_f32(&A_mul_B_matrix, row, column, A_mul_B_data);
    arm_mat_mult_f32(&A, &B, &A_mul_B_matrix);

    /* 分配并计算A'矩阵 */
    row = A.numCols;
    column = A.numRows;
    float A_Trans_data[row * column];
    arm_matrix_instance_f32 A_Trans_matrix;
    arm_mat_init_f32(&A_Trans_matrix, row, column, A_Trans_data);
    arm_mat_trans_f32(&A, &A_Trans_matrix);

    /* 分配并计算A*B*A'矩阵 */
    arm_mat_mult_f32(&A_mul_B_matrix, &A_Trans_matrix, &output);
}

/* A为单位矩阵时，根据t-1时刻的状态量和t时刻控制量预测t时刻状态预测量 */
void Kalman_Filter_Function::Forecast_State_Variable(const arm_matrix_instance_f32 &sys_state,   /* t-1时刻的状态量 */
                                                     const arm_matrix_instance_f32 &B,           /* 控制转换矩阵B */
                                                     const arm_matrix_instance_f32 &input,       /* 控制量 */
                                                     arm_matrix_instance_f32 &sys_state_forecast /* t时刻的状态预测量 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 计算B*input矩阵的大小 */
    row = B.numRows;
    column = input.numCols;

    /* 分配并计算B*input矩阵 */
    float B_mul_input_data[row * column];
    arm_matrix_instance_f32 B_mul_input_matrix;
    arm_mat_init_f32(&B_mul_input_matrix, row, column, B_mul_input_data);
    arm_mat_mult_f32(&B, &input, &B_mul_input_matrix);

    /* 计算预测量 */
    arm_mat_add_f32(&sys_state, &B_mul_input_matrix, &sys_state_forecast);
}

/* C为单位矩阵时，计算卡尔曼增益矩阵 */
void Kalman_Filter_Function::Cal_Kalman_Gain(const arm_matrix_instance_f32 &covariance_forecast, /* t时刻的协方差预测矩阵 */
                                             const arm_matrix_instance_f32 &Q,                   /* 测量噪声矩阵Q */
                                             arm_matrix_instance_f32 &K                          /* 卡尔曼增益矩阵 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    row = covariance_forecast.numRows;
    column = covariance_forecast.numCols;

    /* 分配并计算Sigma+Q */
    float Sigma_add_Q_data[row * column];
    arm_matrix_instance_f32 Sigma_add_Q_matrix;
    arm_mat_init_f32(&Sigma_add_Q_matrix, row, column, Sigma_add_Q_data);
    arm_mat_add_f32(&covariance_forecast, &Q, &Sigma_add_Q_matrix);

    /* 分配并计算Sigma_add_Q_matrix^-1 */
    row = Sigma_add_Q_matrix.numCols;
    column = Sigma_add_Q_matrix.numRows;
    float Sigma_add_Q_inv_data[row * column];
    arm_matrix_instance_f32 Sigma_add_Q_inv_matrix;
    arm_mat_init_f32(&Sigma_add_Q_inv_matrix, row, column, Sigma_add_Q_inv_data);
    arm_mat_inverse_f32(&Sigma_add_Q_matrix, &Sigma_add_Q_inv_matrix);

    /* 计算卡尔曼增益矩阵 */
    arm_mat_mult_f32(&covariance_forecast, &Sigma_add_Q_inv_matrix, &K);
}

/* C为单位矩阵时，使用测量值更新状态量 */
void Kalman_Filter_Function::Update_State_Variable(const arm_matrix_instance_f32 &sys_state_forecast, /* t时刻的状态预测量 */
                                                   const arm_matrix_instance_f32 &K,                  /* 卡尔曼增益矩阵 */
                                                   const arm_matrix_instance_f32 &measurement,        /* t时刻测量量 */
                                                   arm_matrix_instance_f32 &sys_state                 /* t时刻的状态量 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    row = measurement.numRows;
    column = measurement.numCols;

    /* 分配并计算Z-sys_state_forecast矩阵 */
    float Z_Sub_sys_state_forecast_data[row * column];
    arm_matrix_instance_f32 Z_Sub_sys_state_forecast_matrix;
    arm_mat_init_f32(&Z_Sub_sys_state_forecast_matrix, row, column, Z_Sub_sys_state_forecast_data);
    arm_mat_sub_f32(&measurement, &sys_state_forecast, &Z_Sub_sys_state_forecast_matrix);

    /* 分配并计算K*Z_Sub_sys_state_forecast_matrix */
    row = K.numRows;
    column = Z_Sub_sys_state_forecast_matrix.numCols;
    float K_mul_Z_Sub_sys_state_forecast_data[row * column];
    arm_matrix_instance_f32 K_mul_Z_Sub_sys_state_forecast_matrix;
    arm_mat_init_f32(&K_mul_Z_Sub_sys_state_forecast_matrix, row, column, K_mul_Z_Sub_sys_state_forecast_data);
    arm_mat_mult_f32(&K, &Z_Sub_sys_state_forecast_matrix, &K_mul_Z_Sub_sys_state_forecast_matrix);

    /* 更新状态量 */
    arm_mat_add_f32(&sys_state_forecast, &K_mul_Z_Sub_sys_state_forecast_matrix, &sys_state);
}

/* C为单位矩阵时，更新协方差 */
void Kalman_Filter_Function::Update_Covariance(const arm_matrix_instance_f32 &covariance_forecast, /* t时刻的协方差预测矩阵 */
                                               const arm_matrix_instance_f32 &I,                   /* 单位矩阵 */
                                               const arm_matrix_instance_f32 &K,                   /* 卡尔曼增益矩阵 */
                                               arm_matrix_instance_f32 &covariance                 /* t时刻的协方差矩阵 */
)
{
    uint16_t row;    /* 行 */
    uint16_t column; /* 列 */

    /* 分配并计算I-K */
    row = I.numRows;
    column = I.numCols;
    float I_sub_K_data[row * column];
    arm_matrix_instance_f32 I_sub_K_matrix;
    arm_mat_init_f32(&I_sub_K_matrix, row, column, I_sub_K_data);
    arm_mat_sub_f32(&I, &K, &I_sub_K_matrix);

    /* 更新协方差 */
    arm_mat_mult_f32(&I_sub_K_matrix, &covariance_forecast, &covariance);
}
